import torch
from torch import nn
import torch.nn.functional as F

from .CausalLayer import SCM
from .AttentionLayer import Attention

# Prior network used to capture the seeker's emotion information; information flows between successive utterances.
class PriorNet(nn.Module):
    """Prior network over the seeker's latent state.

    The constructor builds the sub-modules (a structural causal model, a
    layer norm, an attention layer, an emotion/utterance fusion MLP, three
    GRU cells and the mean / log-variance projections).  The ``forward``
    computation is not part of this file: the authors removed it pending
    publication, so calling ``forward`` raises ``NotImplementedError``.

    Args:
        utter_dim: Feature size of an utterance representation.
        latent_dim: Size of the latent space.  The three GRU cells each use
            ``latent_dim // 3`` hidden units.
            NOTE(review): presumably ``latent_dim`` should be divisible by 3
            so that the three hidden states concatenate back to
            ``latent_dim`` - confirm against the full implementation.
        emotion_dim: Accepted for interface compatibility but currently
            ignored; ``self.emotion_dim`` is always set to ``utter_dim``.
        prior_type: Only ``"GRU"`` creates the fusion MLP and the GRU cells;
            any other value leaves those attributes undefined.
        mu_type: Only ``"share"`` creates ``prior_mu_FC`` and
            ``prior_logvar_FC``; any other value leaves them undefined.
        var_type: Stored on the instance but not consulted by the
            constructor.
        activation: Accepted for interface compatibility; not used by the
            constructor.
        dropout_prob: Stored as ``self.dropout`` (a plain float, not an
            ``nn.Dropout`` module).
    """

    def __init__(self,
                 utter_dim,
                 latent_dim,
                 emotion_dim=None,
                 prior_type='GRU',
                 mu_type='share',
                 var_type='share',
                 activation=nn.ReLU(inplace=True),
                 dropout_prob=0.1,
                 ):
        super().__init__()

        self.utter_dim = utter_dim
        self.prior_type = prior_type
        self.mu_type = mu_type
        self.var_type = var_type
        # NOTE(review): the ``emotion_dim`` argument is deliberately left
        # unused here to preserve existing layer shapes; the emotion feature
        # size is tied to ``utter_dim``.  Confirm whether that is intended.
        self.emotion_dim = utter_dim

        self.latent_dim = latent_dim
        self.dropout = dropout_prob

        # Sub-modules are registered in a fixed order so that parameter
        # initialisation and ``state_dict`` layout stay stable.
        self.SCM_model = SCM(in_dim=self.latent_dim, hidden_dim=self.latent_dim, hidden_num=3)
        self.layernorm = nn.LayerNorm(self.latent_dim)
        self.attention = Attention(self.latent_dim)

        if prior_type == "GRU":
            # Fuses a concatenated (emotion, utterance) feature vector of
            # size ``emotion_dim + utter_dim`` into ``latent_dim`` features.
            self.emotion_feature_fusion = nn.Sequential(
                nn.Linear(self.emotion_dim + self.utter_dim, self.latent_dim // 2),
                nn.ReLU(),
                nn.Linear(self.latent_dim // 2, self.latent_dim)
            )

            # One recurrent cell per factor (cognitive / behavior / emotion);
            # each takes ``2 * latent_dim`` inputs and keeps a
            # ``latent_dim // 3`` hidden state.
            self.cognitive_prior_eps_sec_GRU = nn.GRUCell(self.latent_dim * 2, self.latent_dim // 3)
            self.behavior_prior_eps_sec_GRU = nn.GRUCell(self.latent_dim * 2, self.latent_dim // 3)
            self.emotion_prior_eps_sec_GRU = nn.GRUCell(self.latent_dim * 2, self.latent_dim // 3)

        if self.mu_type == 'share':
            # Both projections are gated on ``mu_type``; ``var_type`` is not
            # checked here.
            self.prior_mu_FC = nn.Linear(self.latent_dim, self.latent_dim)
            self.prior_logvar_FC = nn.Linear(self.latent_dim, self.latent_dim)

    # The following core implementation of the code has been removed, and the full code will be released upon the paper's acceptance.
    def forward(self, emotion_cur, seeker_cur, z_pr_last, z_po_last, mask=None):
        """Placeholder for the removed forward pass.

        The original stub returned the tuple ``(full_z_sec_cur, full_z_mu,
        full_z_logvar, full_eps_mu, full_eps_logvar, causal_loss)`` (several
        hidden variables concatenated together), but none of those names was
        ever computed, so every call failed with ``NameError``.

        Raises:
            NotImplementedError: Always, until the full implementation is
                released.
        """
        raise NotImplementedError(
            "PriorNet.forward has been removed from this release; the full "
            "implementation will be published upon the paper's acceptance."
        )

